package com.lagou.homework.filter;

import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Bucket4j;
import io.github.bucket4j.Refill;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Global gateway filter that protects against request flooding from a single IP.
 *
 * <p>Each client IP gets its own token bucket. A request is forwarded down the
 * filter chain only if one token can be consumed from that bucket; otherwise the
 * request is rejected with HTTP 429. Bucket capacity, refill amount and refill
 * period are read from the {@code getway.*} configuration properties (the original
 * note on this class mentions a limit of 30 requests per minute).
 *
 * <p>NOTE(review): the bucket cache has no eviction, so it grows by one entry per
 * distinct client IP for the lifetime of the bean. Consider a bounded/expiring
 * cache if the gateway is exposed to many distinct addresses.
 */
@Slf4j
@RefreshScope
@Component
public class ProtectIPGatewayFilter implements GlobalFilter, Ordered {

    /** Cache key shared by all requests whose remote address cannot be determined. */
    private static final String UNKNOWN_CLIENT = "unknown";

    /** Maximum capacity of a bucket, i.e. the largest number of tokens it can hold. */
    @Value("${getway.capacity}")
    private int capacity;

    /** Number of tokens added to a bucket per refill period. */
    @Value("${getway.refillTokens}")
    private int refillTokens;

    /** Length of the refill period, in seconds. */
    @Value("${getway.second}")
    private int second;

    /**
     * Per-IP buckets. Deliberately an instance field rather than a static one:
     * this bean is refresh-scoped, so a configuration refresh creates a new
     * instance with the new limits, and buckets built from the old limits must
     * not survive into it.
     */
    private final Map<String, Bucket> bucketCache = new ConcurrentHashMap<>();

    /**
     * Builds a bucket from the currently configured capacity and refill settings.
     *
     * @return a new bucket with a single bandwidth limit
     */
    private Bucket createNewBucket() {
        Duration refillPeriod = Duration.ofSeconds(second);
        Refill refill = Refill.greedy(refillTokens, refillPeriod);
        Bandwidth limit = Bandwidth.classic(capacity, refill);
        return Bucket4j.builder().addLimit(limit).build();
    }

    /**
     * Consumes one token from the caller's bucket and either forwards the request
     * or rejects it with HTTP 429.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return completion signal of either the downstream chain or the rejection response
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String ip = resolveClientIp(exchange.getRequest());
        Bucket bucket = bucketCache.computeIfAbsent(ip, key -> createNewBucket());

        // Guarded so the token count is only queried when debug logging is on.
        if (log.isDebugEnabled()) {
            log.debug("IP: {}，has Tokens: {}", ip, bucket.getAvailableTokens());
        }

        if (bucket.tryConsume(1)) {
            return chain.filter(exchange);
        }
        return rejectTooManyRequests(exchange.getResponse());
    }

    /**
     * Returns the textual remote IP of the request, or {@link #UNKNOWN_CLIENT} when
     * the request carries no remote address or the address is unresolved.
     */
    private static String resolveClientIp(ServerHttpRequest request) {
        InetSocketAddress remote = request.getRemoteAddress();
        if (remote == null) {
            return UNKNOWN_CLIENT;
        }
        InetAddress address = remote.getAddress();
        if (address == null) {
            return UNKNOWN_CLIENT;
        }
        return address.getHostAddress();
    }

    /**
     * Writes the HTTP 429 rejection response.
     *
     * <p>NOTE(review): the body is plain text although the Content-Type header
     * announces JSON; left unchanged because clients may depend on the current payload.
     */
    private static Mono<Void> rejectTooManyRequests(ServerHttpResponse response) {
        response.getHeaders().set("Content-Type", "application/json;charset=UTF-8");
        response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
        String data = "错误码：303，错误信息：您频繁进⾏注册，请求已被拒绝";
        DataBuffer body = response.bufferFactory().wrap(data.getBytes(StandardCharsets.UTF_8));
        return response.writeWith(Mono.just(body));
    }

    /**
     * @return the filter order value, {@code 0}
     */
    @Override
    public int getOrder() {
        return 0;
    }
}
